import torch

from MyModel import CIFAR10Model
from torchvision.transforms import Compose, ToTensor, Resize
from PIL import Image

# Single-image inference script: load one image, preprocess it to a
# 1x3x32x32 float tensor, run it through a saved model on CPU, and print
# the index of the highest-scoring output.

img_path = "./dataset/img.png"
# Open via a context manager so the underlying file handle is released;
# convert() returns a new, fully loaded image that stays valid afterwards.
with Image.open(img_path) as raw_img:
    # PNG files may carry 4 channels (RGBA), so force 3-channel RGB.
    img = raw_img.convert("RGB")
print(img)

img_compose = Compose([
    Resize((32, 32)),
    ToTensor(),
])

img_tensor = img_compose(img)
# Add the leading batch dimension: (3, 32, 32) -> (1, 3, 32, 32).
img_tensor = img_tensor.unsqueeze(0)

# If the model was trained on a GPU, its tensors must be remapped to the CPU.
# NOTE(review): torch.load unpickles arbitrary objects -- only load
# checkpoints from a trusted source. On PyTorch >= 2.6 the default
# weights_only=True rejects whole-model pickles; pass weights_only=False
# there if this call fails. Presumably the pickled class is CIFAR10Model,
# which is why it has to be importable here -- confirm.
model = torch.load("./my_model_6139.pth", map_location=torch.device('cpu'))
# Switch to evaluation mode so layers such as Dropout/BatchNorm (if the
# model has any) use their inference behaviour instead of training behaviour.
model.eval()

# No gradients are needed for inference; skip building the autograd graph.
with torch.no_grad():
    index = model(img_tensor).argmax()
print(index)